# -*- coding: utf-8 -*-
#
# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
# holder of all proprietary rights on this computer program.
# Using this computer program means that you agree to the terms 
# in the LICENSE file included with this software distribution. 
# Any use not explicitly granted by the LICENSE is prohibited.
#
# Copyright©2019 Max-Planck-Gesellschaft zur Förderung
# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
# for Intelligent Systems. All rights reserved.
#
# For comments or questions, please email us at deca@tue.mpg.de
# For commercial licensing contact, please contact ps-license@tuebingen.mpg.de

import torch

''' Rotation Converter
Representations: 
	euler angle(3), angle axis(3), rotation matrix(3x3), quaternion(4), continuous representation
Ref: 
	https://kornia.readthedocs.io/en/v0.1.2/_modules/torchgeometry/core/conversions.html#
	smplx/lbs
'''

# Module-level pi as a 1-element tensor created with the default tensor type
# (normally float32, so float64 callers see a float32-rounded value).
# rad2deg / deg2rad move and cast it to the input tensor's device and dtype.
pi = torch.Tensor([3.14159265358979323846])
def rad2deg(tensor):
	"""Convert angles from radians to degrees.

	Args:
		tensor (Tensor): Tensor of arbitrary shape holding angles in radians.

	Returns:
		Tensor: Tensor with the same shape as the input, in degrees. Integer
		(or bool) inputs are promoted to the default floating-point dtype.

	Raises:
		TypeError: If ``tensor`` is not a ``torch.Tensor``.

	Example:
		>>> input = pi * torch.rand(1, 3, 3)
		>>> output = rad2deg(input)
	"""
	if not torch.is_tensor(tensor):
		raise TypeError("Input type is not a torch.Tensor. Got {}"
						.format(type(tensor)))

	# Casting the module-level ``pi`` to an integer dtype would truncate it to
	# 3, so integral inputs are promoted to a floating dtype before dividing.
	if not (tensor.is_floating_point() or tensor.is_complex()):
		tensor = tensor.to(torch.get_default_dtype())

	return 180. * tensor / pi.to(device=tensor.device, dtype=tensor.dtype)

def deg2rad(tensor):
	"""Convert angles from degrees to radians.

	Args:
		tensor (Tensor): Tensor of arbitrary shape holding angles in degrees.

	Returns:
		Tensor: Tensor with the same shape as the input, in radians. Integer
		(or bool) inputs are promoted to the default floating-point dtype.

	Raises:
		TypeError: If ``tensor`` is not a ``torch.Tensor``.

	Examples::

		>>> input = 360. * torch.rand(1, 3, 3)
		>>> output = deg2rad(input)
	"""
	if not torch.is_tensor(tensor):
		raise TypeError("Input type is not a torch.Tensor. Got {}"
						.format(type(tensor)))

	# Casting the module-level ``pi`` to an integer dtype would truncate it to
	# 3 (deg2rad(180) == 3), so integral inputs are promoted first.
	if not (tensor.is_floating_point() or tensor.is_complex()):
		tensor = tensor.to(torch.get_default_dtype())

	return tensor * pi.to(device=tensor.device, dtype=tensor.dtype) / 180.

######### to quaternion
def euler_to_quaternion(r):
	"""Convert Euler angles (radians) to quaternions ordered (w, x, y, z).

	The result equals the Hamilton product q_x * q_y * q_z of the three
	single-axis rotations, i.e. the rotation matrix Rx @ Ry @ Rz.

	Args:
		r (Tensor): angles of shape (..., 3), last axis ordered (x, y, z).

	Returns:
		Tensor: quaternions of shape (..., 4). For backward compatibility a
		single 1-D angle triple of shape (3,) yields shape (1, 4).
	"""
	half = r / 2.0
	cx, cy, cz = torch.cos(half).unbind(-1)
	sx, sy, sz = torch.sin(half).unbind(-1)

	quaternion = torch.stack([
		cx*cy*cz - sx*sy*sz,
		cx*sy*sz + cy*cz*sx,
		cx*cz*sy - sx*cy*sz,
		cx*cy*sz + sx*cz*sy,
	], dim=-1)

	if r.dim() == 1:
		# Legacy behaviour: the previous implementation returned (1, 4) here.
		quaternion = quaternion.unsqueeze(0)
	return quaternion

def rotation_matrix_to_quaternion(rotation_matrix, eps=1e-6):
	"""Convert a batch of rotation matrices to quaternions ordered (w, x, y, z).

	This algorithm is based on algorithm described in
	https://github.com/KieranWynn/pyquaternion/blob/master/pyquaternion/quaternion.py#L201

	Only the upper-left 3x3 block of each matrix is read, so both 3x3 and
	3x4 (rotation + translation) inputs are accepted; the explicit 3x4 shape
	check below is commented out.

	Args:
		rotation_matrix (Tensor): batch of matrices, shape (N, 3, 3) or (N, 3, 4).
		eps (float): threshold on the transposed matrix's [2, 2] entry, used
			to choose which of the four candidate formulas is applied.

	Return:
		Tensor: the rotation as quaternions, scalar part first.

	Raises:
		TypeError: if ``rotation_matrix`` is not a tensor.
		ValueError: if ``rotation_matrix`` has more than three dimensions.

	Shape:
		- Input: :math:`(N, 3, 4)` or :math:`(N, 3, 3)`
		- Output: :math:`(N, 4)`

	Example:
		>>> input = torch.rand(4, 3, 4)  # Nx3x4
		>>> output = rotation_matrix_to_quaternion(input)  # Nx4
	"""
	if not torch.is_tensor(rotation_matrix):
		raise TypeError("Input type is not a torch.Tensor. Got {}".format(
			type(rotation_matrix)))

	# NOTE(review): a 2-D (3, 3) input passes this check but the
	# torch.transpose(..., 1, 2) below fails on it; callers must pass a batch.
	if len(rotation_matrix.shape) > 3:
		raise ValueError(
			"Input size must be a three dimensional tensor. Got {}".format(
				rotation_matrix.shape))
	# if not rotation_matrix.shape[-2:] == (3, 4):
	#     raise ValueError(
	#         "Input size must be a N x 3 x 4  tensor. Got {}".format(
	#             rotation_matrix.shape))

	# The formulas below are written in terms of the transposed matrix;
	# only indices 0..2 are used, i.e. the 3x3 rotation block.
	rmat_t = torch.transpose(rotation_matrix, 1, 2)

	# Per-matrix branch selectors, following the referenced pyquaternion
	# algorithm (presumably to avoid dividing by a small sqrt(t_i)).
	mask_d2 = rmat_t[:, 2, 2] < eps

	mask_d0_d1 = rmat_t[:, 0, 0] > rmat_t[:, 1, 1]
	mask_d0_nd1 = rmat_t[:, 0, 0] < -rmat_t[:, 1, 1]

	# Case 0: t0 sits in the x slot of q0.
	t0 = 1 + rmat_t[:, 0, 0] - rmat_t[:, 1, 1] - rmat_t[:, 2, 2]
	q0 = torch.stack([rmat_t[:, 1, 2] - rmat_t[:, 2, 1],
					  t0, rmat_t[:, 0, 1] + rmat_t[:, 1, 0],
					  rmat_t[:, 2, 0] + rmat_t[:, 0, 2]], -1)
	t0_rep = t0.repeat(4, 1).t()

	# Case 1: t1 sits in the y slot of q1.
	t1 = 1 - rmat_t[:, 0, 0] + rmat_t[:, 1, 1] - rmat_t[:, 2, 2]
	q1 = torch.stack([rmat_t[:, 2, 0] - rmat_t[:, 0, 2],
					  rmat_t[:, 0, 1] + rmat_t[:, 1, 0],
					  t1, rmat_t[:, 1, 2] + rmat_t[:, 2, 1]], -1)
	t1_rep = t1.repeat(4, 1).t()

	# Case 2: t2 sits in the z slot of q2.
	t2 = 1 - rmat_t[:, 0, 0] - rmat_t[:, 1, 1] + rmat_t[:, 2, 2]
	q2 = torch.stack([rmat_t[:, 0, 1] - rmat_t[:, 1, 0],
					  rmat_t[:, 2, 0] + rmat_t[:, 0, 2],
					  rmat_t[:, 1, 2] + rmat_t[:, 2, 1], t2], -1)
	t2_rep = t2.repeat(4, 1).t()

	# Case 3: t3 = 1 + trace sits in the w slot of q3.
	t3 = 1 + rmat_t[:, 0, 0] + rmat_t[:, 1, 1] + rmat_t[:, 2, 2]
	q3 = torch.stack([t3, rmat_t[:, 1, 2] - rmat_t[:, 2, 1],
					  rmat_t[:, 2, 0] - rmat_t[:, 0, 2],
					  rmat_t[:, 0, 1] - rmat_t[:, 1, 0]], -1)
	t3_rep = t3.repeat(4, 1).t()

	# The four masks partition the batch: exactly one of mask_c0..mask_c3 is
	# 1 for each row, so the weighted sums below select a single case.
	mask_c0 = mask_d2 * mask_d0_d1.float()
	mask_c1 = mask_d2 * (1 - mask_d0_d1.float())
	mask_c2 = (1 - mask_d2.float()) * mask_d0_nd1
	mask_c3 = (1 - mask_d2.float()) * (1 - mask_d0_nd1.float())
	mask_c0 = mask_c0.view(-1, 1).type_as(q0)
	mask_c1 = mask_c1.view(-1, 1).type_as(q1)
	mask_c2 = mask_c2.view(-1, 1).type_as(q2)
	mask_c3 = mask_c3.view(-1, 1).type_as(q3)

	# Result per row: q_i / (2 * sqrt(t_i)) for the selected case i.
	q = q0 * mask_c0 + q1 * mask_c1 + q2 * mask_c2 + q3 * mask_c3
	q /= torch.sqrt(t0_rep * mask_c0 + t1_rep * mask_c1 +  # noqa
					t2_rep * mask_c2 + t3_rep * mask_c3)  # noqa
	q *= 0.5
	return q

# def angle_axis_to_quaternion(theta):
#     batch_size = theta.shape[0]
#     l1norm = torch.norm(theta + 1e-8, p=2, dim=1)
#     angle = torch.unsqueeze(l1norm, -1)
#     normalized = torch.div(theta, angle)
#     angle = angle * 0.5
#     v_cos = torch.cos(angle)
#     v_sin = torch.sin(angle)
#     quat = torch.cat([v_cos, v_sin * normalized], dim=1)
#     return quat

def angle_axis_to_quaternion(angle_axis: torch.Tensor) -> torch.Tensor:
	"""Convert an angle-axis (rotation) vector to a quaternion.

	Adapted from ceres C++ library: ceres-solver/include/ceres/rotation.h

	The rotation angle is the Euclidean norm of the input vector and the axis
	is its direction. The quaternion is ordered (w, x, y, z), scalar first.
	A zero vector maps to the identity (1, 0, 0, 0) with finite gradients.

	Args:
		angle_axis (torch.Tensor): tensor with angle axis.

	Return:
		torch.Tensor: tensor with quaternion.

	Raises:
		TypeError: if ``angle_axis`` is not a tensor.
		ValueError: if the last dimension is not 3.

	Shape:
		- Input: :math:`(*, 3)` where `*` means, any number of dimensions
		- Output: :math:`(*, 4)`

	Example:
		>>> angle_axis = torch.rand(2, 3)  # Nx3
		>>> quaternion = angle_axis_to_quaternion(angle_axis)  # Nx4
	"""
	if not torch.is_tensor(angle_axis):
		raise TypeError("Input type is not a torch.Tensor. Got {}".format(
			type(angle_axis)))

	if not angle_axis.shape[-1] == 3:
		raise ValueError("Input must be a tensor of shape Nx3 or 3. Got {}"
						 .format(angle_axis.shape))
	# unpack input and compute conversion
	a0: torch.Tensor = angle_axis[..., 0:1]
	a1: torch.Tensor = angle_axis[..., 1:2]
	a2: torch.Tensor = angle_axis[..., 2:3]
	theta_squared: torch.Tensor = a0 * a0 + a1 * a1 + a2 * a2
	mask: torch.Tensor = theta_squared > 0.0

	# Substitute 1 for zero angles before sqrt/divide: torch.where still
	# back-propagates through the unselected branch, and sqrt(0) / (0 / 0)
	# there would turn the gradient into NaN.
	safe_theta_squared = torch.where(
		mask, theta_squared, torch.ones_like(theta_squared))
	theta: torch.Tensor = torch.sqrt(safe_theta_squared)
	half_theta: torch.Tensor = theta * 0.5
	ones: torch.Tensor = torch.ones_like(half_theta)

	# sin(theta / 2) / theta -> 1 / 2 as theta -> 0.
	k: torch.Tensor = torch.where(mask, torch.sin(half_theta) / theta, 0.5 * ones)
	w: torch.Tensor = torch.where(mask, torch.cos(half_theta), ones)

	return torch.cat([w, angle_axis * k], dim=-1)

#### quaternion to
def quaternion_to_rotation_matrix(quat):
	"""Convert quaternion coefficients to rotation matrices.

	The input is normalised to unit length first, so non-unit quaternions
	are accepted (a zero quaternion yields NaNs).

	Args:
		quat: quaternions of shape (..., 4), ordered (w, x, y, z).
	Returns:
		Rotation matrices of shape (..., 3, 3); for the usual (B, 4) input
		this is (B, 3, 3).
	"""
	norm_quat = quat / quat.norm(p=2, dim=-1, keepdim=True)
	w, x, y, z = norm_quat.unbind(dim=-1)

	w2, x2, y2, z2 = w.pow(2), x.pow(2), y.pow(2), z.pow(2)
	wx, wy, wz = w * x, w * y, w * z
	xy, xz, yz = x * y, x * z, y * z

	rot_mat = torch.stack([w2 + x2 - y2 - z2, 2 * xy - 2 * wz, 2 * wy + 2 * xz,
						   2 * wz + 2 * xy, w2 - x2 + y2 - z2, 2 * yz - 2 * wx,
						   2 * xz - 2 * wy, 2 * wx + 2 * yz, w2 - x2 - y2 + z2], dim=-1)
	# Row-major 3x3 layout; leading batch dims are preserved.
	return rot_mat.view(*quat.shape[:-1], 3, 3)

def quaternion_to_angle_axis(quaternion: torch.Tensor):
	"""Convert quaternions (w, x, y, z) to angle-axis rotation vectors.

	Adapted from ceres C++ library: ceres-solver/include/ceres/rotation.h

	When w < 0 the equivalent quaternion -q is used, so the returned angle
	lies in [-pi, pi]. Zero vector parts map to a zero rotation vector with
	finite gradients.

	Args:
		quaternion (torch.Tensor): tensor with quaternions.

	Return:
		torch.Tensor: tensor with angle axis of rotation.

	Raises:
		TypeError: if ``quaternion`` is not a tensor.
		ValueError: if the last dimension is not 4.

	Shape:
		- Input: :math:`(*, 4)` where `*` means, any number of dimensions
		- Output: :math:`(*, 3)`

	Example:
		>>> quaternion = torch.rand(2, 4)  # Nx4
		>>> angle_axis = quaternion_to_angle_axis(quaternion)  # Nx3
	"""
	if not torch.is_tensor(quaternion):
		raise TypeError("Input type is not a torch.Tensor. Got {}".format(
			type(quaternion)))

	if not quaternion.shape[-1] == 4:
		raise ValueError("Input must be a tensor of shape Nx4 or 4. Got {}"
						 .format(quaternion.shape))
	# unpack input and compute conversion
	q1: torch.Tensor = quaternion[..., 1]
	q2: torch.Tensor = quaternion[..., 2]
	q3: torch.Tensor = quaternion[..., 3]
	sin_squared_theta: torch.Tensor = q1 * q1 + q2 * q2 + q3 * q3
	mask: torch.Tensor = sin_squared_theta > 0.0

	# Substitute 1 where the vector part is zero so neither sqrt(0) nor 0 / 0
	# appears in the branch torch.where discards (it would poison gradients).
	safe_sin_squared_theta = torch.where(
		mask, sin_squared_theta, torch.ones_like(sin_squared_theta))
	sin_theta: torch.Tensor = torch.sqrt(safe_sin_squared_theta)
	cos_theta: torch.Tensor = quaternion[..., 0]
	two_theta: torch.Tensor = 2.0 * torch.where(
		cos_theta < 0.0,
		torch.atan2(-sin_theta, -cos_theta),
		torch.atan2(sin_theta, cos_theta))

	k_pos: torch.Tensor = two_theta / sin_theta
	# Small-angle limit of two_theta / sin_theta.
	k_neg: torch.Tensor = 2.0 * torch.ones_like(sin_theta)
	k: torch.Tensor = torch.where(mask, k_pos, k_neg)

	return quaternion[..., 1:] * k.unsqueeze(-1)

#### batch converter
def batch_euler2axis(r):
	"""Convert Euler angles to angle-axis vectors via a quaternion.

	Args:
		r: Euler angles, forwarded unchanged to ``euler_to_quaternion``.

	Returns:
		Angle-axis vectors from ``quaternion_to_angle_axis``.
	"""
	quat = euler_to_quaternion(r)
	return quaternion_to_angle_axis(quat)

def batch_euler2matrix(r):
	"""Convert Euler angles to rotation matrices via a quaternion.

	Args:
		r: Euler angles, forwarded unchanged to ``euler_to_quaternion``.

	Returns:
		Rotation matrices from ``quaternion_to_rotation_matrix``.
	"""
	quat = euler_to_quaternion(r)
	return quaternion_to_rotation_matrix(quat)

def batch_matrix2euler(rot_mats):
	"""Convert a batch of rotation matrices to three Euler angles each.

	Vectorised over the batch. Away from gimbal lock, for each matrix ``R``::

		x = asin(R[2, 0])
		y = atan2(R[2, 1], R[2, 2])
		z = atan2(R[1, 0], R[0, 0])

	which corresponds to ``R = Rz(z) @ Ry(-x) @ Rx(y)``. NOTE(review): this
	is not the convention ``batch_euler2matrix`` builds (``Rx @ Ry @ Rz``),
	so the two are not inverses — confirm which convention callers expect.

	Near gimbal lock (``|R[2, 0]| > 0.998``), ``z`` is fixed to 0, ``x`` to
	+/- pi/2, and ``y`` absorbs the remaining rotation.

	Args:
		rot_mats (Tensor): rotation matrices of shape (B, 3, 3).

	Returns:
		Tensor: angles (x, y, z) in radians, shape (B, 3), on the same device
		and dtype as ``rot_mats``.
	"""
	import math

	r20 = rot_mats[:, 2, 0]
	near_pos = r20 > 0.998
	near_neg = r20 < -0.998
	regular = ~(near_pos | near_neg)

	# Regular branch. Gimbal rows get a dummy 0 before asin so the entries
	# torch.where discards stay finite (and cannot produce NaN gradients).
	safe_r20 = torch.where(regular, r20, torch.zeros_like(r20))
	x_reg = torch.asin(safe_r20)
	cos_x = torch.cos(x_reg)
	y_reg = torch.atan2(rot_mats[:, 2, 1] / cos_x, rot_mats[:, 2, 2] / cos_x)
	z_reg = torch.atan2(rot_mats[:, 1, 0] / cos_x, rot_mats[:, 0, 0] / cos_x)

	# Gimbal-lock branches (z = 0).
	y_pos = torch.atan2(-rot_mats[:, 0, 1], -rot_mats[:, 0, 2])
	y_neg = torch.atan2(rot_mats[:, 0, 1], rot_mats[:, 0, 2])
	half_pi = torch.full_like(r20, math.pi / 2)

	x = torch.where(near_pos, half_pi, torch.where(near_neg, -half_pi, x_reg))
	y = torch.where(near_pos, y_pos, torch.where(near_neg, y_neg, y_reg))
	z = torch.where(regular, z_reg, torch.zeros_like(r20))
	return torch.stack([x, y, z], dim=-1)

def batch_matrix2axis(rot_mats):
	"""Convert rotation matrices to angle-axis vectors via a quaternion.

	Args:
		rot_mats: matrices, forwarded unchanged to
			``rotation_matrix_to_quaternion``.

	Returns:
		Angle-axis vectors from ``quaternion_to_angle_axis``.
	"""
	quat = rotation_matrix_to_quaternion(rot_mats)
	return quaternion_to_angle_axis(quat)

def batch_axis2matrix(theta):
	"""Convert angle-axis vectors to rotation matrices via a quaternion.

	``batch_rodrigues`` computes the same mapping directly with Rodrigues'
	formula.

	Args:
		theta: angle-axis vectors (N x 3), forwarded unchanged to
			``angle_axis_to_quaternion``.

	Returns:
		Rotation matrices from ``quaternion_to_rotation_matrix``.
	"""
	quat = angle_axis_to_quaternion(theta)
	return quaternion_to_rotation_matrix(quat)

def batch_axis2euler(theta):
	"""Convert angle-axis vectors to Euler angles via rotation matrices.

	Args:
		theta: angle-axis vectors, forwarded unchanged to ``batch_axis2matrix``.

	Returns:
		Euler angles from ``batch_matrix2euler``.
	"""
	rot_mats = batch_axis2matrix(theta)
	return batch_matrix2euler(rot_mats)



def batch_orth_proj(X, camera):
	'''Apply a scaled orthographic camera to a batch of 3D points.

	Args:
		X: points of shape (N, num_points, 3).
		camera: 3 values per batch element, read as (scale, tx, ty); any
			shape that reshapes to (N, 1, 3) is accepted.

	Returns:
		Tensor of shape (N, num_points, 3) equal to
		``scale * (x + tx, y + ty, z)``. The depth channel is scaled but not
		translated.
	'''
	# reshape (rather than clone().view) also handles non-contiguous inputs;
	# camera is never modified, so no copy is needed.
	camera = camera.reshape(-1, 1, 3)
	X_trans = X[:, :, :2] + camera[:, :, 1:]
	X_trans = torch.cat([X_trans, X[:, :, 2:]], 2)
	return camera[:, :, 0:1] * X_trans

def batch_rodrigues(rot_vecs, epsilon=1e-8, dtype=torch.float32):
	'''Convert a batch of axis-angle vectors to rotation matrices.

	Same purpose as ``batch_axis2matrix`` (axis-angle -> matrix), computed
	directly with Rodrigues' formula ``R = I + sin(a) K + (1 - cos(a)) K^2``,
	where ``K`` is the cross-product matrix of the unit axis.

		Parameters
		----------
		rot_vecs: torch.tensor Nx3
			array of N axis-angle vectors
		epsilon: float
			added to every component before taking the norm, so that an
			all-zero vector does not produce 0 / 0 when normalising
		dtype: torch.dtype
			dtype of the identity and zero blocks used to assemble the result
		Returns
		-------
		R: torch.tensor Nx3x3
			The rotation matrices for the given axis-angle parameters
	'''

	batch_size = rot_vecs.shape[0]
	device = rot_vecs.device

	angle = torch.norm(rot_vecs + epsilon, dim=1, keepdim=True)
	rot_dir = rot_vecs / angle

	cos = torch.unsqueeze(torch.cos(angle), dim=1)
	sin = torch.unsqueeze(torch.sin(angle), dim=1)

	# Bx1 arrays
	rx, ry, rz = torch.split(rot_dir, 1, dim=1)

	# Skew-symmetric cross-product matrix of the unit axis, row-major.
	zeros = torch.zeros((batch_size, 1), dtype=dtype, device=device)
	K = torch.cat([zeros, -rz, ry, rz, zeros, -rx, -ry, rx, zeros], dim=1) \
		.view((batch_size, 3, 3))

	ident = torch.eye(3, dtype=dtype, device=device).unsqueeze(dim=0)
	return ident + sin * K + (1 - cos) * torch.bmm(K, K)
